{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deploy cohere-command-r Model Package from AWS Marketplace \n",
    "\n",
    "\n",
    "Cohere builds a collection of Large Language Models (LLMs) trained on a massive corpus of curated web data. Powering these models, our infrastructure enables our product to be deployed for a wide range of use cases. The use cases we power include generation (copy writing, etc), summarization, classification, content moderation, information extraction, semantic search, and contextual entity extraction\n",
    "\n",
    "This sample notebook shows you how to deploy [cohere-command-r: ml.p4de.24xlarge](https://aws.amazon.com/marketplace/pp/prodview-w7ukdez7zfjfo) or [cohere-command-r: ml.p5.48xlarge](https://aws.amazon.com/marketplace/pp/prodview-jfhyfeewxqbr2) using Amazon SageMaker.\n",
    "\n",
 **Note**: This is a reference">
    "> **Note**: This is a reference notebook, and it will not run until you make the changes suggested in the notebook.\n",
    "\n",
    "> The cohere-command-r model package supports SageMaker real-time inference but not SageMaker Batch Transform.\n",
    "\n",
    "## Pre-requisites:\n",
    "1. **Note**: This notebook contains elements which render correctly in Jupyter interface. Open this notebook from an Amazon SageMaker Notebook Instance or Amazon SageMaker Studio.\n",
    "1. Ensure that IAM role used has **AmazonSageMakerFullAccess**\n",
    "1. To deploy this ML model successfully, ensure that:\n",
    "    1. Either your IAM role has these three permissions and you have authority to make AWS Marketplace subscriptions in the AWS account used: \n",
    "        1. **aws-marketplace:ViewSubscriptions**\n",
    "        1. **aws-marketplace:Unsubscribe**\n",
    "        1. **aws-marketplace:Subscribe**  \n",
    "    2. or your AWS account has a subscription to [cohere-command-r: ml.p4de.24xlarge](https://aws.amazon.com/marketplace/pp/prodview-w7ukdez7zfjfo) or [cohere-command-r: ml.p5.48xlarge](https://aws.amazon.com/marketplace/pp/prodview-jfhyfeewxqbr2). If so, skip step: [Subscribe to the model package](#1.-Subscribe-to-the-model-package)\n",
    "\n",
    "## Contents:\n",
    "1. [Subscribe to the model package](#1.-Subscribe-to-the-model-package)\n",
    "2. [Create an endpoint and perform real-time inference](#2.-Create-an-endpoint-and-perform-real-time-inference)\n",
    "   1. [Create an endpoint](#A.-Create-an-endpoint)\n",
    "   2. [Create input payload](#B.-Create-input-payload)\n",
    "   3. [Perform real-time inference](#C.-Perform-real-time-inference)\n",
    "   4. [Visualize output](#D.-Visualize-output)\n",
    "   5. [Streaming Chat](#E.-Streaming-Chat)\n",
    "   6. [Chat with documents (RAG)](#F.-Chat-with-documets-(RAG))\n",
    "   7. [Generate search queries](#G.-generate-search-queries)\n",
    "   8. [Tool inputs](#H.-Tool-inputs)\n",
    "   9. [Tool results](#I.-Tool-results)\n",
    "3. [Clean-up](#4.-Clean-up)\n",
    "    1. [Delete the model](#A.-Delete-the-model)\n",
    "    2. [Unsubscribe to the listing (optional)](#B.-Unsubscribe-to-the-listing-(optional))\n",
    "    \n",
    "\n",
    "## Usage instructions\n",
    "You can run this notebook one cell at a time (By using Shift+Enter for running a cell)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Subscribe to the model package"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set the instance_type depeding on which hardware you want to use\n",
    "\n",
    "instance_type = 'ml.p4de.24xlarge'\n",
    "# instance_type = 'ml.p5.48xlarge'"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To subscribe to the model package:\n",
    "\n",
    "1. Open the model package listing page [cohere-command-r: ml.p4de.24xlarge](https://aws.amazon.com/marketplace/pp/prodview-w7ukdez7zfjfo) or [cohere-command-r: ml.p5.48xlarge](https://aws.amazon.com/marketplace/pp/prodview-jfhyfeewxqbr2)\n",
    "1. On the AWS Marketplace listing, click on the **Continue to subscribe** button.\n",
    "1. On the **Subscribe to this software** page, review and click on **\"Accept Offer\"** if you and your organization agrees with EULA, pricing, and support terms. \n",
    "1. Once you click on **Continue to configuration button** and then choose a **region**, you will see a **Product Arn** displayed. This is the model package ARN that you need to specify while creating a deployable model using Boto3. Copy the ARN corresponding to your region and specify the same in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade cohere-aws\n",
    "# if you upgrade the package, you need to restart the kernel\n",
    "\n",
    "from cohere_aws import Client\n",
    "import boto3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cohere_packages = {\n",
    "    'ml.p4de.24xlarge': \"cohere-command-r-a100-v1-1\",\n",
    "    'ml.p5.48xlarge': \"cohere-command-r-h100-v1-1\",\n",
    "}\n",
    "\n",
    "cohere_package = cohere_packages[instance_type]\n",
    "\n",
    "# Mapping for Model Packages\n",
    "model_package_map = {\n",
    "    \"us-east-1\": f\"arn:aws:sagemaker:us-east-1:865070037744:model-package/{cohere_package}\",\n",
    "    \"us-east-2\": f\"arn:aws:sagemaker:us-east-2:057799348421:model-package/{cohere_package}\",\n",
    "    \"us-west-1\": f\"arn:aws:sagemaker:us-west-1:382657785993:model-package/{cohere_package}\",\n",
    "    \"us-west-2\": f\"arn:aws:sagemaker:us-west-2:594846645681:model-package/{cohere_package}\",\n",
    "    \"ca-central-1\": f\"arn:aws:sagemaker:ca-central-1:470592106596:model-package/{cohere_package}\",\n",
    "    \"eu-central-1\": f\"arn:aws:sagemaker:eu-central-1:446921602837:model-package/{cohere_package}\",\n",
    "    \"eu-west-1\": f\"arn:aws:sagemaker:eu-west-1:985815980388:model-package/{cohere_package}\",\n",
    "    \"eu-west-2\": f\"arn:aws:sagemaker:eu-west-2:856760150666:model-package/{cohere_package}\",\n",
    "    \"eu-west-3\": f\"arn:aws:sagemaker:eu-west-3:843114510376:model-package/{cohere_package}\",\n",
    "    \"eu-north-1\": f\"arn:aws:sagemaker:eu-north-1:136758871317:model-package/{cohere_package}\",\n",
    "    \"ap-southeast-1\": f\"arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/{cohere_package}\",\n",
    "    \"ap-southeast-2\": f\"arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/{cohere_package}\",\n",
    "    \"ap-northeast-2\": f\"arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/{cohere_package}\",\n",
    "    \"ap-northeast-1\": f\"arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/{cohere_package}\",\n",
    "    \"ap-south-1\": f\"arn:aws:sagemaker:ap-south-1:077584701553:model-package/{cohere_package}\",\n",
    "    \"sa-east-1\": f\"arn:aws:sagemaker:sa-east-1:270155090741:model-package/{cohere_package}\",\n",
    "}\n",
    "\n",
    "region = boto3.Session().region_name\n",
    "if region not in model_package_map.keys():\n",
    "    raise Exception(f\"Current boto3 session region {region} is not supported.\")\n",
    "\n",
    "model_package_arn = model_package_map[region]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create an endpoint and perform real-time inference"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to understand how real-time inference with Amazon SageMaker works, see [Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-hosting.html)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A. Create an endpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "co = Client(region_name=region)\n",
    "co.create_endpoint(arn=model_package_arn, endpoint_name=\"cohere-command-r\", instance_type=instance_type, n_instances=1)\n",
    "\n",
    "# If the endpoint is already created, you just need to connect to it\n",
    "# co.connect_to_endpoint(endpoint_name=\"cohere-command-r\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once endpoint has been created, you would be able to perform real-time inference."
   ]
  },
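  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm that the endpoint is ready before sending requests, you can read its status from SageMaker. The cell below is a minimal sketch and is not part of the `cohere_aws` client: the helper name `endpoint_status` is hypothetical, and the usage comment assumes the endpoint name `cohere-command-r` used above. It relies on the boto3 SageMaker `describe_endpoint` call, whose response includes an `EndpointStatus` field; the endpoint is ready once that field reads `InService`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def endpoint_status(sagemaker_client, endpoint_name):\n",
    "    # Ask SageMaker for the endpoint description and return its status string.\n",
    "    description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "    return description['EndpointStatus']\n",
    "\n",
    "\n",
    "# Usage (assumes AWS credentials and the endpoint created above):\n",
    "# sm = boto3.client('sagemaker', region_name=region)\n",
    "# print(endpoint_status(sm, 'cohere-command-r'))"
   ]
  },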
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### B. Create input payload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message = \"Write a LinkedIn post about starting a career in tech:\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### C. Perform real-time inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = co.chat(message=message, stream=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### D. Visualize output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(response)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### E. Streaming Chat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message = \"Write a LinkedIn post about starting a career in tech:\"\n",
    "\n",
    "response = co.chat(message=message, stream=True)\n",
    "\n",
    "# stream events back\n",
    "for res in response:\n",
    "    print(res)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### F. Chat with documents (RAG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message=\"How deep in the Mariana Trench\"\n",
    "documents = [\n",
    "    {\n",
    "       \"id\": \"national_geographic_everest\",\n",
    "       \"title\": \"Height of Mount Everest\",\n",
    "       \"snippet\": \"The height of Mount Everest is 29,035 feet\",\n",
    "       \"url\": \"https://education.nationalgeographic.org/resource/mount-everest/\",\n",
    "    },\n",
    "    {\n",
    "        \"id\": \"national_geographic_mariana\",\n",
    "        \"title\": \"Depth of the Mariana Trench\",\n",
    "        \"snippet\": \"The depth of the Mariana Trench is 36,070 feet\",\n",
    "        \"url\": \"https://www.nationalgeographic.org/activity/mariana-trench-deepest-place-earth\",\n",
    "    }\n",
    "]\n",
    "\n",
    "response = co.chat(message=message, documents=documents, stream=False)\n",
    "\n",
    "print(response)"
   ]
  },
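  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The documents above are hard-coded. In practice they are usually selected from a larger collection before being passed to `co.chat`. The cell below is a minimal sketch of that selection step, using a naive keyword-overlap score as a stand-in for a real search index or embedding-based retriever. The function name `select_documents`, the `top_k` parameter, and the `candidate_documents` list are hypothetical and not part of the `cohere_aws` client."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "candidate_documents = [\n",
    "    {'id': 'everest', 'title': 'Height of Mount Everest', 'snippet': 'The height of Mount Everest is 29,035 feet'},\n",
    "    {'id': 'mariana', 'title': 'Depth of the Mariana Trench', 'snippet': 'The depth of the Mariana Trench is 36,070 feet'},\n",
    "]\n",
    "\n",
    "\n",
    "def select_documents(question, candidates, top_k=1):\n",
    "    # Naive scoring: count question words that also appear in the title or snippet.\n",
    "    question_words = set(question.lower().replace('?', '').split())\n",
    "\n",
    "    def overlap(doc):\n",
    "        text = (doc['title'] + ' ' + doc['snippet']).lower()\n",
    "        return len(question_words & set(text.split()))\n",
    "\n",
    "    # Highest-scoring documents first; keep only the top_k.\n",
    "    return sorted(candidates, key=overlap, reverse=True)[:top_k]\n",
    "\n",
    "\n",
    "selected = select_documents('How deep is the Mariana Trench?', candidate_documents)\n",
    "print([doc['id'] for doc in selected])\n",
    "# The selected list could then be passed to co.chat(message=..., documents=selected)."
   ]
  },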
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### G. Generate search queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message=\"What is the height of Mount Everest?\"\n",
    "\n",
    "response = co.chat(message=message, search_queries_only=True, stream=False)\n",
    "\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### H. Tool inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message=\"What were sales like on October 31st?\"\n",
    "tools=[\n",
    "    {\n",
    "        \"name\": \"sales_database\",\n",
    "        \"description\": \"Connects to a database about sales volumes\",\n",
    "        \"parameter_definitions\": {\n",
    "            \"day\": {\n",
    "                \"description\": \"Retrieves sales data from this day, formatted as YYYY-MM-DD\",\n",
    "                \"type\": \"str\",\n",
    "                \"required\": True\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "]\n",
    "\n",
    "response = co.chat(message=message, tools=tools, stream=False)\n",
    "\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### I. Tool results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message=\"What were sales like on October 31st?\"\n",
    "tools=[\n",
    "    {\n",
    "        \"name\": \"sales_database\",\n",
    "        \"description\": \"Connects to a database about sales volumes\",\n",
    "        \"parameter_definitions\": {\n",
    "            \"day\": {\n",
    "                \"description\": \"Retrieves sales data from this day, formatted as YYYY-MM-DD\",\n",
    "                \"type\": \"str\",\n",
    "                \"required\": True\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "]\n",
    "tool_results=[\n",
    "    {\n",
    "        \"call\": {\n",
    "            \"name\": \"sales_database\",\n",
    "            \"parameters\": {\n",
    "                \"day\": \"2023-04-08\"\n",
    "            }\n",
    "        },\n",
    "        \"outputs\": [{\"number_of_sales\": 120, \"total_revenue\": 48500, \"day\": \"2023-04-08\"}]\n",
    "    }\n",
    "]\n",
    "\n",
    "response = co.chat(message=message, tools=tools, tool_results=tool_results, stream=False)\n",
    "\n",
    "print(response)"
   ]
  },
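  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `tool_results` list above is written by hand. In an application it would normally be built by running the tool that the model asked for. The cell below is a minimal sketch of one way to do that in plain Python. It assumes each tool call is available as a dict with `name` and `parameters` keys (the same shape as the `call` entry above); `query_sales`, `TOOL_FUNCTIONS`, and `run_tool_calls` are hypothetical names, and the sales figures are placeholder data rather than real output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical local implementation of the sales_database tool (placeholder data).\n",
    "def query_sales(day):\n",
    "    return [{'number_of_sales': 120, 'total_revenue': 48500, 'day': day}]\n",
    "\n",
    "\n",
    "# Map tool names (as declared in `tools`) to local Python functions.\n",
    "TOOL_FUNCTIONS = {'sales_database': query_sales}\n",
    "\n",
    "\n",
    "def run_tool_calls(tool_calls):\n",
    "    # Run each requested tool and return entries in the tool_results shape.\n",
    "    results = []\n",
    "    for call in tool_calls:\n",
    "        function = TOOL_FUNCTIONS[call['name']]\n",
    "        outputs = function(**call['parameters'])\n",
    "        results.append({'call': call, 'outputs': outputs})\n",
    "    return results\n",
    "\n",
    "\n",
    "example_calls = [{'name': 'sales_database', 'parameters': {'day': '2023-04-08'}}]\n",
    "example_results = run_tool_calls(example_calls)\n",
    "print(example_results)"
   ]
  },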
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Clean-up"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A. Delete the endpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "co.delete_endpoint()\n",
    "co.close()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### B. Unsubscribe to the listing (optional)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you would like to unsubscribe to the model package, follow these steps. Before you cancel the subscription, ensure that you do not have any [deployable model](https://console.aws.amazon.com/sagemaker/home#/models) created from the model package or using the algorithm. Note - You can find this information by looking at the container name associated with the model. \n",
    "\n",
    "**Steps to unsubscribe to product from AWS Marketplace**:\n",
    "1. Navigate to __Machine Learning__ tab on [__Your Software subscriptions page__](https://aws.amazon.com/marketplace/ai/library?productType=ml&ref_=mlmp_gitdemo_indust)\n",
    "2. Locate the listing that you want to cancel the subscription for, and then choose __Cancel Subscription__  to cancel the subscription.\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.1"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
